{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[1. , 2.1],\n",
       "        [2. , 1.1],\n",
       "        [1.3, 1. ],\n",
       "        [1. , 1. ],\n",
       "        [2. , 1. ]]),\n",
       " array([ 1,  1, -1, -1,  1]))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def load_data():\n",
    "    x = np.array([[1.0, 2.1], [2.0, 1.1], [1.3, 1.0], [1.0, 1.0], [2.0, 1.0]])\n",
    "    y = np.array([1, 1, -1, -1, 1])\n",
    "    return x, y\n",
    "\n",
    "\n",
    "x, y = load_data()\n",
    "x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.2, 0.2, 0.2, 0.2, 0.2])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "N = len(x)\n",
    "\n",
    "#数据权重,在初始化时,认为所有的数据都是等价的\n",
    "D = np.empty(5)\n",
    "D.fill(1 / N)\n",
    "D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tree{col=0,value=1.00,eq=<,weight=1.00}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "-1"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Tree():\n",
    "    def __init__(self, col, value, eq):\n",
    "        self.col = col\n",
    "        self.value = value\n",
    "        self.eq = eq\n",
    "        #权重\n",
    "        self.weight = 1\n",
    "\n",
    "    #预测方法,简单的根据某个值分割数据\n",
    "    def __call__(self, xi):\n",
    "        if self.eq == '<':\n",
    "            if xi[self.col] < self.value:\n",
    "                return 1\n",
    "            return -1\n",
    "\n",
    "        if self.eq == '>':\n",
    "            if xi[self.col] >= self.value:\n",
    "                return 1\n",
    "            return -1\n",
    "\n",
    "    def __str__(self):\n",
    "        return 'Tree{col=%d,value=%.2f,eq=%s,weight=%.2f}' % (\n",
    "            self.col, self.value, self.eq, self.weight)\n",
    "\n",
    "\n",
    "tree = Tree(0, 1, '<')\n",
    "print(tree)\n",
    "tree(x[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6000000000000001"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#计算树的loss,考虑数据的权重,越重要的数据,惩罚的越严重\n",
    "def get_loss(tree):\n",
    "    loss = 0\n",
    "    for xi, yi, di in zip(x, y, D):\n",
    "        pred = tree(xi)\n",
    "        if pred != yi:\n",
    "            loss += di\n",
    "    return loss\n",
    "\n",
    "\n",
    "get_loss(tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tree{col=0,value=1.30,eq=>,weight=1.00}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.2"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#训练一棵树,总的来说,就是求loss最小\n",
    "def get_tree():\n",
    "\n",
    "    min_loss = np.inf\n",
    "\n",
    "    min_col = 0\n",
    "    min_value = 0\n",
    "    min_eq = '<'\n",
    "\n",
    "    min_loss_tree = None\n",
    "\n",
    "    #遍历所有列\n",
    "    for col in range(x.shape[1]):\n",
    "\n",
    "        #遍历符号\n",
    "        for eq in ['<', '>']:\n",
    "\n",
    "            #从 列最小-0.1 遍历到 列最大+0.1\n",
    "            col_min = x[:, col].min() - 0.1\n",
    "            col_max = x[:, col].max() + 0.1\n",
    "\n",
    "            value = col_min\n",
    "\n",
    "            #遍历value值\n",
    "            while value < col_max:\n",
    "                tree = Tree(col, value, eq)\n",
    "                loss = get_loss(tree)\n",
    "\n",
    "                if loss < min_loss:\n",
    "                    min_loss = loss\n",
    "                    min_tree = tree\n",
    "\n",
    "                value += 0.1\n",
    "\n",
    "    return min_tree\n",
    "\n",
    "\n",
    "tree = get_tree()\n",
    "print(tree)\n",
    "get_loss(tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6931471805599453"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#计算树的权重\n",
    "def get_tree_weight(tree):\n",
    "    #在当前数据权重的情况下,计算loss\n",
    "    loss = get_loss(tree)\n",
    "\n",
    "    #计算权重,这是一个恒正的数,loss约低,权重越大\n",
    "    #防止分母为0\n",
    "    weight = (1 - loss) / max(loss, 1e-5)\n",
    "\n",
    "    #取对数,防止小数连乘\n",
    "    return np.log(weight) / 2\n",
    "\n",
    "\n",
    "tree.weight = get_tree_weight(tree)\n",
    "tree.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.5  , 0.125, 0.125, 0.125, 0.125])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#计算数据的权重\n",
    "def get_D(tree):\n",
    "\n",
    "    new_D = np.empty(N)\n",
    "\n",
    "    for i in range(N):\n",
    "        #如果预测结果不正确,则增加数据的权重.预测正确,则减小数据的权重,当然也要考虑树本身的权重\n",
    "        temp = tree.weight * -y[i] * tree(x[i])\n",
    "\n",
    "        #取exp,可以认为是转换为了百分比,负数取exp,是一个小于1的数, 正数取exp, 是一个大于1的数.当然,exp是恒正的\n",
    "        temp = np.exp(temp)\n",
    "\n",
    "        #让D在上面的的计算结果上伸展\n",
    "        new_D[i] = D[i] * temp\n",
    "\n",
    "    #归一化\n",
    "    new_D = new_D / new_D.sum()\n",
    "\n",
    "    return new_D\n",
    "\n",
    "\n",
    "get_D(tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tree{col=0,value=1.30,eq=>,weight=0.69}\n",
      "Tree{col=1,value=1.10,eq=>,weight=0.97}\n",
      "Tree{col=0,value=2.00,eq=<,weight=0.90}\n",
      "Tree{col=0,value=1.30,eq=>,weight=0.80}\n",
      "Tree{col=1,value=1.10,eq=>,weight=0.78}\n",
      "Tree{col=0,value=2.00,eq=<,weight=0.75}\n",
      "Tree{col=0,value=1.30,eq=>,weight=0.74}\n",
      "Tree{col=1,value=1.10,eq=>,weight=0.73}\n",
      "Tree{col=0,value=2.00,eq=<,weight=0.73}\n",
      "Tree{col=0,value=1.30,eq=>,weight=0.73}\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    global D\n",
    "\n",
    "    trees = []\n",
    "    #训练10棵树\n",
    "    for i in range(10):\n",
    "        tree = get_tree()\n",
    "        tree.weight = get_tree_weight(tree)\n",
    "        trees.append(tree)\n",
    "\n",
    "        #重新计算数据权重\n",
    "        D = get_D(tree)\n",
    "\n",
    "    return trees\n",
    "\n",
    "\n",
    "#重新初始化数据权重\n",
    "D.fill(1 / N)\n",
    "D\n",
    "\n",
    "trees = train()\n",
    "for i in trees:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#预测函数\n",
    "def prediction(trees, xi):\n",
    "\n",
    "    #就是累加树的权重*树的结果,因为前面取了对数权重了,所以这里用加号而不是乘号\n",
    "    pred = 0\n",
    "    for tree in trees:\n",
    "        pred += tree.weight * tree(xi)\n",
    "\n",
    "    #判断正负\n",
    "    pred = np.sign(pred)\n",
    "    return pred\n",
    "\n",
    "\n",
    "prediction(trees, x[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#测试\n",
    "correct = 0\n",
    "for xi, yi in zip(x, y):\n",
    "    pred = prediction(trees, xi)\n",
    "    if pred == yi:\n",
    "        correct += 1\n",
    "\n",
    "correct / N"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
